{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NOoAYv5ThsHg"
      },
      "source": [
        "# Three Ways to Build a Text Classifier with Cohere"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AVf0QmYJiJ22"
      },
      "source": [
        "With LLMs, you no longer need to prepare thousands of training data points: you can get up and running with just a handful of examples, an approach called *few-shot* classification. That said, you will probably want some control over how you train a classifier and, especially, over how to get the best performance out of a model. For example, if you do have a large dataset at your disposal, you will want to make full use of it when training a classifier. The Cohere API aims to give developers this flexibility."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eiOMX1UaSN2V"
      },
      "source": [
        "***Read the accompanying [blog post here.](https://txt.cohere.ai/classify-three-options/)***"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "![[Blog] Three Ways to Build a Text Classifier with the Cohere API](https://github.com/cohere-ai/notebooks/raw/main/notebooks/images/classify-three-options/classify-options-feat.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "29uwe-jzJ9rh"
      },
      "outputs": [],
      "source": [
        "# TODO: upgrade to \"cohere>5\"\n",
        "! pip install \"cohere<5\" > /dev/null"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "y9-RyLu7KHII"
      },
      "outputs": [],
      "source": [
        "# Import the required modules\n",
        "import cohere\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import accuracy_score\n",
        "from sklearn.metrics import f1_score"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "AW4RsVSE4j74"
      },
      "outputs": [],
      "source": [
        "# Set up the Cohere client\n",
        "api_key = 'apikey' # Paste your API key here. Remember not to share it publicly\n",
        "co = cohere.Client(api_key)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "laIkCcKa40PR"
      },
      "source": [
        "# Prepare the Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Fil-WkzipV_"
      },
      "source": [
        "We'll use a subset of the [Airline Travel Information System (ATIS)](https://www.kaggle.com/datasets/hassanamin/atis-airlinetravelinformationsystem?select=atis_intents_train.csv) intent classification dataset. [[Source](https://aclanthology.org/H90-1021/)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 60,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "gI7wMIdOrbiK",
        "outputId": "ebbaf0c0-60a2-487f-d1b3-7ec34175c35e"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-1bbf40a0-666e-4c4d-9969-39f995fae4c6\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>query</th>\n",
              "      <th>intent</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>i want to fly from boston at 838 am and arriv...</td>\n",
              "      <td>atis_flight</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>what flights are available from pittsburgh to...</td>\n",
              "      <td>atis_flight</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>what is the arrival time in san francisco for...</td>\n",
              "      <td>atis_flight_time</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>cheapest airfare from tacoma to orlando</td>\n",
              "      <td>atis_airfare</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>round trip fares from pittsburgh to philadelp...</td>\n",
              "      <td>atis_airfare</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-1bbf40a0-666e-4c4d-9969-39f995fae4c6')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-1bbf40a0-666e-4c4d-9969-39f995fae4c6 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-1bbf40a0-666e-4c4d-9969-39f995fae4c6');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "                                               query            intent\n",
              "0   i want to fly from boston at 838 am and arriv...       atis_flight\n",
              "1   what flights are available from pittsburgh to...       atis_flight\n",
              "2   what is the arrival time in san francisco for...  atis_flight_time\n",
              "3            cheapest airfare from tacoma to orlando      atis_airfare\n",
              "4   round trip fares from pittsburgh to philadelp...      atis_airfare"
            ]
          },
          "execution_count": 60,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Load the dataset into a dataframe\n",
        "df = pd.read_csv('https://raw.githubusercontent.com/cohere-ai/notebooks/main/notebooks/data/atis_subset.csv',names=['query','intent'])\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 61,
      "metadata": {
        "id": "cM-2KGlQe80t"
      },
      "outputs": [],
      "source": [
        "# Split the dataset into training and test portions\n",
        "# Training = For sampling the few-shot examples (Section 1) and for training (Sections 2 and 3)\n",
        "# Test = For evaluating the classifier performance\n",
        "X, y = df[\"query\"], df[\"intent\"]\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=200, random_state=21)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 62,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-J-T5ZKP5s0t",
        "outputId": "41c946b1-9f46-44d0-bb49-31f348d252dd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "['atis_flight', 'atis_airfare', 'atis_ground_service', 'atis_flight_time', 'atis_airline', 'atis_quantity', 'atis_abbreviation', 'atis_aircraft']\n"
          ]
        }
      ],
      "source": [
        "# View the list of all available categories\n",
        "intents = y_train.unique().tolist()\n",
        "print(intents)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "79RhgMQr452c"
      },
      "source": [
        "# 1 - Few-shot classification with the Classify endpoint"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T-22ncqsi8sH"
      },
      "source": [
        "Few-shot here means we just need to supply a few examples per class to get a decent classifier working. With Cohere’s Classify endpoint, the ‘training’ dataset is referred to as *examples*. The minimum number of examples per class is two, where each example consists of a text (in our case, the `query`) and a label (in our case, the `intent`). More examples generally help, though, so here we'll use six examples per class."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yiHoK7A34-iJ"
      },
      "source": [
        "## Prepare the examples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 63,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IzWMEEb4LmRz",
        "outputId": "5a671716-d390-473f-856b-5c28453a4ece"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Number of examples per class: 6\n",
            "Number of classes: 8\n",
            "Total number of examples: 48\n"
          ]
        }
      ],
      "source": [
        "# Set the number of examples per category\n",
        "EX_PER_CAT = 6\n",
        "\n",
        "# Create list of examples containing texts and labels - sample from the dataset\n",
        "ex_texts, ex_labels = [], []\n",
        "for intent in intents:\n",
        "  y_temp = y_train[y_train == intent]\n",
        "  sample_indexes = y_temp.sample(n=EX_PER_CAT, random_state=42).index\n",
        "  ex_texts += X_train[sample_indexes].tolist()\n",
        "  ex_labels += y_train[sample_indexes].tolist()\n",
        "\n",
        "print(f'Number of examples per class: {EX_PER_CAT}')\n",
        "print(f'Number of classes: {len(intents)}')\n",
        "print(f'Total number of examples: {len(ex_texts)}')"
      ]
    },
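    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sampling above assumes that every intent has at least `EX_PER_CAT` rows in the training set. As a minimal sketch, the check below counts the examples per class and flags any class that falls below the minimum of two mentioned earlier. The helper name `find_underrepresented` and the toy labels are hypothetical and used only for illustration; in this notebook you could pass `ex_labels` instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch: check that every class has enough few-shot examples\n",
        "from collections import Counter\n",
        "\n",
        "MIN_EX_PER_CLASS = 2  # Minimum number of examples per class stated above\n",
        "\n",
        "def find_underrepresented(labels, minimum=MIN_EX_PER_CLASS):\n",
        "    # Count the examples per label, then keep the labels below the minimum\n",
        "    counts = Counter(labels)\n",
        "    return sorted(label for label, n in counts.items() if n < minimum)\n",
        "\n",
        "# Toy labels for illustration only (not sampled from the dataset)\n",
        "toy_labels = ['atis_flight', 'atis_flight', 'atis_airfare']\n",
        "print(find_underrepresented(toy_labels))  # ['atis_airfare']"
      ]
    },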
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "29oLOWrY6TnM"
      },
      "source": [
        "## Get classifications via the Classify endpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WujlVX-ERNW4"
      },
      "outputs": [],
      "source": [
        "# Collate the examples using the Example class\n",
        "from cohere.responses.classify import Example\n",
        "\n",
        "examples = [Example(txt, lbl) for txt, lbl in zip(ex_texts, ex_labels)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 86,
      "metadata": {
        "id": "NXCHB13-rbZZ"
      },
      "outputs": [],
      "source": [
        "# Generate classification predictions on the test dataset\n",
        "\n",
        "# Classification function\n",
        "def classify_text(texts, examples):\n",
        "    classifications = co.classify(\n",
        "        inputs=texts,\n",
        "        examples=examples\n",
        "    )\n",
        "    return [c.predictions[0] for c in classifications]\n",
        "\n",
        "# Create batches of texts and classify them\n",
        "BATCH_SIZE = 90 # The API accepts a maximum of 96 inputs\n",
        "y_pred = []\n",
        "for i in range(0, len(X_test), BATCH_SIZE):\n",
        "    batch_texts = X_test[i:i+BATCH_SIZE].tolist()\n",
        "    y_pred.extend(classify_text(batch_texts, examples))\n"
      ]
    },
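    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The batching loop above can be factored into a small reusable helper. The sketch below shows one possible way to do it; `batched` is a hypothetical helper name, not part of the Cohere SDK, and the toy input merely stands in for the 200 test queries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch: split a list into batches that fit within an API input limit\n",
        "def batched(items, batch_size):\n",
        "    # Yield consecutive slices of items with at most batch_size elements each\n",
        "    if batch_size < 1:\n",
        "        raise ValueError('batch_size must be at least 1')\n",
        "    for start in range(0, len(items), batch_size):\n",
        "        yield items[start:start + batch_size]\n",
        "\n",
        "# Toy input standing in for the 200 test queries\n",
        "toy_texts = [f'query {i}' for i in range(200)]\n",
        "batch_lengths = [len(batch) for batch in batched(toy_texts, 90)]\n",
        "print(batch_lengths)  # [90, 90, 20]"
      ]
    },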
    {
      "cell_type": "code",
      "execution_count": 88,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mVxRlIzR74Xz",
        "outputId": "86b085b0-b7cd-4994-e6cc-49b14e93c4d5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 83.00\n",
            "F1-score: 84.66\n"
          ]
        }
      ],
      "source": [
        "# Compute metrics on the test dataset\n",
        "accuracy = accuracy_score(y_test, y_pred)\n",
        "f1 = f1_score(y_test, y_pred, average='weighted')\n",
        "\n",
        "print(f'Accuracy: {100*accuracy:.2f}')\n",
        "print(f'F1-score: {100*f1:.2f}')"
      ]
    },
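    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `average='weighted'` option used above computes an F1-score for each class and then averages those scores, weighting each class by its number of true examples in the test set. The plain-Python sketch below illustrates that calculation on toy labels. `weighted_f1` is a hypothetical helper written only for illustration; `f1_score` from scikit-learn remains the function this notebook relies on."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch: support-weighted F1 computed by hand\n",
        "from collections import Counter\n",
        "\n",
        "def weighted_f1(y_true, y_pred):\n",
        "    if not y_true:\n",
        "        raise ValueError('y_true must not be empty')\n",
        "    support = Counter(y_true)  # Number of true examples per class\n",
        "    total = 0.0\n",
        "    for label, n_true in support.items():\n",
        "        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)\n",
        "        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)\n",
        "        fn = n_true - tp\n",
        "        f1 = 2 * tp / (2 * tp + fp + fn)  # Per-class F1\n",
        "        total += n_true * f1  # Weight the class by its support\n",
        "    return total / len(y_true)\n",
        "\n",
        "# Toy labels for illustration only\n",
        "toy_true = ['a', 'a', 'a', 'b']\n",
        "toy_pred = ['a', 'a', 'b', 'b']\n",
        "print(f'{weighted_f1(toy_true, toy_pred):.4f}')  # 0.7667"
      ]
    },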
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-60rd2N1up2i"
      },
      "source": [
        "# 2 - Build your own classifier with the Embed endpoint"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NbVkkSLXjaim"
      },
      "source": [
        "In this section, we’ll use the Embed endpoint to build a classifier. The endpoint turns each text into an embedding, a vector of numbers that represents the text, and we then train a classification model using these embeddings as inputs. For this, we’ll use the Support Vector Machine (SVM) algorithm."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YioSofGK7OAU"
      },
      "source": [
        "## Generate embeddings for the input text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 89,
      "metadata": {
        "id": "SAbu-iuGZg_3"
      },
      "outputs": [],
      "source": [
        "# Get embeddings\n",
        "def embed_text(text):\n",
        "  output = co.embed(\n",
        "                model='embed-english-v3.0',\n",
        "                input_type=\"classification\",\n",
        "                texts=text)\n",
        "  return output.embeddings\n",
        "\n",
        "# Embed and prepare the inputs\n",
        "X_train_emb = np.array(embed_text(X_train.tolist()))\n",
        "X_test_emb = np.array(embed_text(X_test.tolist()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yqt3Nutm7Gk3"
      },
      "source": [
        "## Get classifications via the SVM algorithm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 95,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-eS6_R6zZiuO",
        "outputId": "a618bb06-7ca1-4b03-fd6f-0a769d1e44a9"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "SVC(class_weight='balanced')"
            ]
          },
          "execution_count": 95,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Import modules\n",
        "from sklearn.svm import SVC\n",
        "from sklearn import preprocessing\n",
        "\n",
        "# Prepare the labels\n",
        "le = preprocessing.LabelEncoder()\n",
        "le.fit(y_train)\n",
        "y_train_le = le.transform(y_train)\n",
        "y_test_le = le.transform(y_test)\n",
        "\n",
        "# Initialize the model\n",
        "svm_classifier = SVC(class_weight='balanced')\n",
        "\n",
        "# Fit the training dataset to the model\n",
        "svm_classifier.fit(X_train_emb, y_train_le)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 96,
      "metadata": {
        "id": "uvY13RjmAXuX"
      },
      "outputs": [],
      "source": [
        "# Generate classification predictions on the test dataset\n",
        "y_pred_le = svm_classifier.predict(X_test_emb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 97,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0EwE3VDZem-q",
        "outputId": "faafd49d-ca05-403a-ac02-9cb5d4a8c58b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 91.50\n",
            "F1-score: 91.01\n"
          ]
        }
      ],
      "source": [
        "# Compute metrics on the test dataset\n",
        "accuracy = accuracy_score(y_test_le, y_pred_le)\n",
        "f1 = f1_score(y_test_le, y_pred_le, average='weighted')\n",
        "\n",
        "print(f'Accuracy: {100*accuracy:.2f}')\n",
        "print(f'F1-score: {100*f1:.2f}')"
      ]
    },
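    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The SVM above is trained on integer-encoded labels, so `y_pred_le` holds integers rather than intent names. As a small sketch of how the encoding round-trips, the example below fits a separate `LabelEncoder` on toy labels and decodes them with `inverse_transform`. The name `toy_le` and the toy labels are illustrative only. In this notebook, `le.inverse_transform(y_pred_le)` would recover the predicted intent names in the same way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch: encode string labels as integers and decode them again\n",
        "from sklearn import preprocessing\n",
        "\n",
        "# LabelEncoder assigns integers following the sorted order of the labels\n",
        "toy_le = preprocessing.LabelEncoder()\n",
        "toy_le.fit(['atis_flight', 'atis_airfare', 'atis_flight'])\n",
        "\n",
        "encoded = toy_le.transform(['atis_flight', 'atis_airfare'])\n",
        "decoded = toy_le.inverse_transform(encoded)\n",
        "\n",
        "print(encoded.tolist())  # [1, 0]\n",
        "print(decoded.tolist())  # ['atis_flight', 'atis_airfare']"
      ]
    },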
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vNHwmhIzr8hv"
      },
      "source": [
        "# 3 - Finetuning a model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DkNhE3Kkj4XT"
      },
      "source": [
        "In this section, we build a custom model that’s finetuned to excel at a specific task and that can potentially outperform the previous two approaches."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jpgdrMbuPrZP"
      },
      "source": [
        "## Prepare dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 94,
      "metadata": {
        "id": "o95wRbeimrXC"
      },
      "outputs": [],
      "source": [
        "# Save the training dataset to a CSV file for finetuning\n",
        "df_train = pd.concat([X_train, y_train],axis=1)\n",
        "df_train.to_csv(\"atis_finetune.csv\", index=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3cIwcAUiRi5e"
      },
      "source": [
        "## Create a finetuned model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qnFRzY19nvrT"
      },
      "source": [
        "The finetuned model is created in the Playground. Refer to [this guide](https://docs.cohere.ai/finetuning-representation-models) for the finetuning steps."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DJDDq9kdCMvH"
      },
      "source": [
        "## Get classifications via the Classify endpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 98,
      "metadata": {
        "id": "Mq8sj5Fn0OuE"
      },
      "outputs": [],
      "source": [
        "# Generate classification predictions on the test dataset using the finetuned model\n",
        "\n",
        "# Classification function\n",
        "def classify_text_finetune(texts, examples):\n",
        "    classifications = co.classify(\n",
        "        model='eeba7d8c-61bd-42cd-a6b5-e31db27403cc-ft', \n",
        "        inputs=texts,\n",
        "        examples=examples\n",
        "    )\n",
        "    return [c.predictions[0] for c in classifications]\n",
        "\n",
        "# Create batches of texts and classify them\n",
        "BATCH_SIZE = 90 # The API accepts a maximum of 96 inputs\n",
        "y_pred = []\n",
        "for i in range(0, len(X_test), BATCH_SIZE):\n",
        "    batch_texts = X_test[i:i+BATCH_SIZE].tolist()\n",
        "    y_pred.extend(classify_text_finetune(batch_texts, examples))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 100,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qJ961ixH9kbf",
        "outputId": "610b9ae4-3c62-4315-a572-ec1c81167eec"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 94.50\n",
            "F1-score: 94.53\n"
          ]
        }
      ],
      "source": [
        "# Compute metrics on the test dataset\n",
        "accuracy = accuracy_score(y_test, y_pred)\n",
        "f1 = f1_score(y_test, y_pred, average='weighted')\n",
        "\n",
        "print(f'Accuracy: {100*accuracy:.2f}')\n",
        "print(f'F1-score: {100*f1:.2f}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aUOq3WzGkd9g"
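      },
      "source": [
        "We have now seen how the three options compare performance-wise: on this test set, accuracy was 83.00% for few-shot classification, 91.50% for the embeddings-based SVM, and 94.50% for the finetuned model. Crucially, note the level of control that you have when working with the Classify endpoint."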
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.10.0 64-bit ('3.10.0')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.10.0"
    },
    "vscode": {
      "interpreter": {
        "hash": "1fb8019e3560b882083e525615cf48e713d3a7345a15eb723d805e91aa410aac"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
